/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <proxygen/lib/http/session/test/MockQuicSocketDriver.h>
#include <proxygen/lib/http/webtransport/QuicWebTransport.h>
#include <proxygen/lib/http/webtransport/test/Mocks.h>
#include <quic/priority/HTTPPriorityQueue.h>

using namespace proxygen;
using namespace proxygen::test;
using namespace testing;
using quic::MockQuicSocketDriver;

// gMock implementation of WebTransport::ByteEventCallback. Tests use it to
// assert which (stream ID, offset) pairs are reported as delivered or as
// canceled.
class MockDeliveryCallback : public WebTransport::ByteEventCallback {
 public:
  // `override` makes the build fail if the base-class signature ever drifts,
  // instead of silently declaring an unrelated method that is never invoked
  // (which would make expectations on this mock fail in confusing ways).
  MOCK_METHOD(void,
              onByteEvent,
              (quic::StreamId, uint64_t),
              (noexcept, override));

  MOCK_METHOD(void,
              onByteEventCanceled,
              (quic::StreamId, uint64_t),
              (noexcept, override));
};

namespace {
// Two arbitrary, distinct application error codes. Tests use them to tell
// which reset / stop-sending / connection-close error reached a stream.
constexpr uint32_t WT_ERROR_1 = 19;
constexpr uint32_t WT_ERROR_2 = 77;
} // namespace

// Test fixture: a QuicWebTransport running over a server-side
// MockQuicSocketDriver. A StrictMock handler makes any unexpected handler
// callback fail the test.
class QuicWebTransportTest : public Test {
 protected:
  void SetUp() override {
    handler_ = std::make_unique<StrictMock<MockWebTransportHandler>>();
    transport_ = std::make_unique<QuicWebTransport>(socketDriver_.getSocket());
    transport_->setHandler(handler_.get());
    // Allow up to 100 unidirectional streams on the mock socket (presumably
    // the credit that createUniStream() consumes).
    socketDriver_.setMaxUniStreams(100);
  }

  void TearDown() override {
    // Some tests destroy the transport themselves; close the session only if
    // it is still alive.
    if (auto* wt = webTransport()) {
      wt->closeSession();
    }
  }

  // The transport viewed through the abstract WebTransport interface, or
  // nullptr once a test has released it.
  WebTransport* webTransport() {
    return transport_.get();
  }

  // Members are destroyed in reverse declaration order. transport_ therefore
  // goes before handler_ (whose raw pointer it was given via setHandler) and
  // before socketDriver_. eventBase_ outlives everything.
  folly::EventBase eventBase_;
  MockQuicSocketDriver socketDriver_{
      &eventBase_,
      nullptr,
      nullptr,
      MockQuicSocketDriver::TransportEnum::SERVER,
      "alpn1"};
  std::unique_ptr<StrictMock<MockWebTransportHandler>> handler_;
  std::unique_ptr<QuicWebTransport> transport_;
};

// A peer-opened unidirectional stream (ID 2) carrying "hello" + FIN is handed
// to the handler. Reading from it yields all 5 bytes with FIN set.
TEST_F(QuicWebTransportTest, PeerUniStream) {
  bool readDelivered = false;
  EXPECT_CALL(*handler_, onNewUniStream(_))
      .WillOnce([this, &readDelivered](
                    WebTransport::StreamReadHandle* readHandle) {
        readHandle->awaitNextRead(
            &eventBase_,
            [&readDelivered](WebTransport::StreamReadHandle*,
                             uint64_t,
                             folly::Try<WebTransport::StreamData> data) {
              readDelivered = true;
              // Fatal checks: the lines below dereference `data` and its
              // buffer, so a failed read must stop here rather than fall
              // through.
              ASSERT_FALSE(data.hasException());
              ASSERT_NE(data->data, nullptr);
              EXPECT_EQ(data->data->computeChainDataLength(), 5);
              EXPECT_TRUE(data->fin);
            });
      });
  socketDriver_.addReadEvent(2, folly::IOBuf::copyBuffer("hello"), true);
  eventBase_.loop();
  // Guard against a vacuous pass if the read callback never ran.
  EXPECT_TRUE(readDelivered);
}

// A peer-opened bidirectional stream (ID 0, FIN only) is handed to the
// handler. Writing "hello" + FIN on it reaches the socket, and the read side
// reports FIN.
TEST_F(QuicWebTransportTest, PeerBidiStream) {
  bool readDelivered = false;
  EXPECT_CALL(*handler_, onNewBidiStream(_))
      .WillOnce([this, &readDelivered](
                    WebTransport::BidiStreamHandle bidiHandle) {
        auto wtImpl = dynamic_cast<WebTransportImpl*>(transport_.get());
        ASSERT_NE(wtImpl, nullptr);
        // Grant connection-level flow control so the write is not blocked.
        wtImpl->onMaxData(1000);
        auto writeRes = bidiHandle.writeHandle->writeStreamData(
            folly::IOBuf::copyBuffer("hello"), true, nullptr);
        EXPECT_FALSE(writeRes.hasError());
        bidiHandle.readHandle->awaitNextRead(
            &eventBase_,
            [&readDelivered](WebTransport::StreamReadHandle*,
                             uint64_t,
                             folly::Try<WebTransport::StreamData> data) {
              readDelivered = true;
              // Fatal: `data->` below would throw on an exception-holding Try.
              ASSERT_FALSE(data.hasException());
              EXPECT_TRUE(data->fin);
            });
      });
  socketDriver_.addReadEvent(0, nullptr, true);
  eventBase_.loop();
  EXPECT_EQ(socketDriver_.streams_[0].writeBuf.chainLength(), 5);
  // Guard against a vacuous pass if the read callback never ran.
  EXPECT_TRUE(readDelivered);
}

// Resetting a locally-created unidirectional stream propagates the
// application error code to the underlying QUIC stream.
TEST_F(QuicWebTransportTest, NewUniStream) {
  auto handle = webTransport()->createUniStream();
  // Fatal: value() below would throw on an error result and hide the real
  // failure behind an unrelated exception.
  ASSERT_TRUE(handle.hasValue());
  auto id = handle.value()->getID();
  handle.value()->resetStream(WT_ERROR_1);
  eventBase_.loop();
  EXPECT_EQ(socketDriver_.streams_[id].error, WT_ERROR_1);
}

// On a locally-created bidirectional stream, stopSending() and resetStream()
// each push their error code to the shared QUIC stream ID. The mock records
// the most recent one.
TEST_F(QuicWebTransportTest, NewBidiStream) {
  auto handle = webTransport()->createBidiStream();
  // Fatal: operator-> below would throw on an error result.
  ASSERT_TRUE(handle.hasValue());
  auto id = handle->readHandle->getID();
  handle->readHandle->stopSending(WT_ERROR_1);
  EXPECT_EQ(socketDriver_.streams_[id].error, WT_ERROR_1);
  handle->writeHandle->resetStream(WT_ERROR_2);
  EXPECT_EQ(socketDriver_.streams_[id].error, WT_ERROR_2);
}

// Outgoing datagrams reach the socket unchanged. An incoming datagram is
// delivered to the handler's onDatagram callback.
TEST_F(QuicWebTransportTest, Datagram) {
  webTransport()->sendDatagram(folly::IOBuf::copyBuffer("hello"));
  // front() on an empty container is undefined behavior; fail cleanly instead.
  ASSERT_FALSE(socketDriver_.outDatagrams_.empty());
  EXPECT_EQ(socketDriver_.outDatagrams_.front().chainLength(), 5);
  socketDriver_.addDatagram(folly::IOBuf::copyBuffer("world!"));
  EXPECT_CALL(*handler_, onDatagram(_))
      .WillOnce([](std::unique_ptr<folly::IOBuf> datagram) {
        EXPECT_EQ(datagram->computeChainDataLength(), 6);
      });
  socketDriver_.addDatagramsAvailableReadEvent();
  eventBase_.loop();
}

// After the peer sends STOP_SENDING, writes on the stream fail with
// ErrorCode::STOP_SENDING, and the stream exposes the peer's error code.
TEST_F(QuicWebTransportTest, OnStopSending) {
  auto handle = webTransport()->createUniStream();
  // Fatal checks: each value()/error()/exception()-> below would throw or
  // dereference invalid state if the preceding condition did not hold.
  ASSERT_TRUE(handle.hasValue());
  auto* stream = handle.value();
  auto id = stream->getID();
  socketDriver_.addStopSending(id, WT_ERROR_1);
  eventBase_.loopOnce();
  auto res = stream->writeStreamData(nullptr, true, nullptr);
  ASSERT_TRUE(res.hasError());
  EXPECT_EQ(res.error(), WebTransport::ErrorCode::STOP_SENDING);
  ASSERT_TRUE(stream->exception());
  EXPECT_EQ(stream->exception()->error, WT_ERROR_1);
}

// When the peer closes the connection with an application error, the handler
// sees onSessionEnd carrying that same error code, even with a stream open.
TEST_F(QuicWebTransportTest, ConnectionError) {
  auto openStream = webTransport()->createUniStream();
  EXPECT_TRUE(openStream.hasValue());

  EXPECT_CALL(*handler_, onSessionEnd(_))
      .WillOnce([](const auto& sessionError) {
        EXPECT_EQ(*sessionError, WT_ERROR_1);
      });

  socketDriver_.deliverConnectionError(
      quic::QuicError(quic::ApplicationErrorCode(WT_ERROR_1), "peer close"));
  eventBase_.loop();
}

// Use-after-free regression test: a handler may destroy the transport from
// inside onSessionEnd. This is safe only because the transport clears its own
// handler pointer before invoking the callback, so nothing is touched after
// the callback returns.
TEST_F(QuicWebTransportTest, ConnectionErrorHandlerDestroysTransport) {
  bool sessionEndSeen = false;
  EXPECT_CALL(*handler_, onSessionEnd(_))
      .WillOnce([&sessionEndSeen, this](const auto& err) {
        EXPECT_EQ(*err, WT_ERROR_1);
        sessionEndSeen = true;
        // Release the transport mid-callback, as an application handler is
        // allowed to. The fixture's TearDown() then skips closeSession()
        // because transport_ is empty.
        transport_.reset();
      });

  socketDriver_.deliverConnectionError(
      quic::QuicError(quic::ApplicationErrorCode(WT_ERROR_1), "peer close"));
  eventBase_.loop();

  EXPECT_TRUE(sessionEndSeen);
  // The callback above released the transport.
  EXPECT_EQ(transport_.get(), nullptr);
}

// setPriority(1, 1, false) on a WebTransport stream must reach the QUIC
// socket as HTTPPriorityQueue::Priority(1, false, 1).
TEST_F(QuicWebTransportTest, SetPriority) {
  auto handle = webTransport()->createUniStream();
  // Fatal: value() below would throw on an error result.
  ASSERT_TRUE(handle.hasValue());
  auto* stream = handle.value();
  socketDriver_.expectSetPriority(
      stream->getID(), quic::HTTPPriorityQueue::Priority(1, false, 1));
  stream->setPriority(1, 1, false);
  stream->writeStreamData(nullptr, true, nullptr);
  eventBase_.loop();
}

// A write with a ByteEventCallback gets exactly one onByteEvent for the
// stream once the mock socket reports the data as delivered.
TEST_F(QuicWebTransportTest, WriteWithDeliveryCallback) {
  auto handle = webTransport()->createUniStream();
  // Fatal: value() below would throw on an error result.
  ASSERT_TRUE(handle.hasValue());
  // NOTE(review): mockCallback is destroyed at the end of this body, before
  // TearDown() closes the session. This is fine as long as delivery fires
  // during loop(). If it did not, a later cancel notification could touch a
  // freed mock — confirm.
  auto mockCallback = std::make_unique<StrictMock<MockDeliveryCallback>>();

  folly::StringPiece data = "test data";

  // Use the callback's own parameter types (quic::StreamId, uint64_t) so
  // data.size() is not narrowed through a 32-bit integer.
  const quic::StreamId expectedStreamId = handle.value()->getID();
  const uint64_t expectedOffset = data.size();
  EXPECT_CALL(*mockCallback, onByteEvent(expectedStreamId, expectedOffset))
      .Times(1);

  auto wtImpl = dynamic_cast<WebTransportImpl*>(transport_.get());
  ASSERT_NE(wtImpl, nullptr);
  // Grant connection-level flow control so the write is not blocked.
  wtImpl->onMaxData(1000);

  auto writeRes = handle.value()->writeStreamData(
      folly::IOBuf::copyBuffer(data), true, mockCallback.get());
  EXPECT_FALSE(writeRes.hasError());
  // The MockQuicSocketDriver automatically simulates the delivery of data
  // written to the QUIC socket.
  eventBase_.loop();
}

// Verifies that when the session ends because of a peer connection error:
//  - the cancellation token of every open peer-initiated read handle is
//    cancelled, and
//  - a readStreamData() that was still pending fails with a
//    WebTransport::Exception.
// Statement order matters here: both streams must be surfaced (loopOnce)
// before the pending read is issued and the connection error is delivered.
TEST_F(QuicWebTransportTest, CloseTransportCancelsReadTokens) {
  WebTransport::StreamReadHandle* readHandle1 = nullptr;
  WebTransport::StreamReadHandle* readHandle2 = nullptr;
  folly::CancellationToken token1;
  folly::CancellationToken token2;

  // Create two read handles from incoming uni streams
  EXPECT_CALL(*handler_, onNewUniStream(_))
      .WillOnce([&](WebTransport::StreamReadHandle* handle) {
        readHandle1 = handle;
        token1 = handle->getCancelToken();
      })
      .WillOnce([&](WebTransport::StreamReadHandle* handle) {
        readHandle2 = handle;
        token2 = handle->getCancelToken();
      });

  // Trigger two incoming uni streams. With a SERVER-side driver, IDs 2 and 6
  // are the first two client-initiated unidirectional stream IDs.
  socketDriver_.addReadEvent(2, nullptr, false);
  socketDriver_.addReadEvent(6, nullptr, false);
  eventBase_.loopOnce();

  ASSERT_NE(readHandle1, nullptr);
  ASSERT_NE(readHandle2, nullptr);
  EXPECT_FALSE(token1.isCancellationRequested());
  EXPECT_FALSE(token2.isCancellationRequested());

  // Invoke read on the first handle (will be pending since no data)
  auto future1 = readHandle1->readStreamData();
  bool readErrorReceived = false;

  // The continuation runs on eventBase_, so it can only fire during the
  // loop() call below, after the connection error has been delivered.
  std::move(future1)
      .via(&eventBase_)
      .thenTry([&](folly::Try<WebTransport::StreamData> result) {
        EXPECT_TRUE(result.hasException());
        // Verify the exception is a WebTransport::Exception
        auto* ex = result.tryGetExceptionObject<WebTransport::Exception>();
        ASSERT_NE(ex, nullptr);
        // Verify the error code matches the connection error
        // This is broken behavior but it's what mvfst currently does
        EXPECT_EQ(ex->error, WT_ERROR_1);
        readErrorReceived = true;
      });

  // Close the transport by delivering a connection error
  EXPECT_CALL(*handler_, onSessionEnd(_));
  socketDriver_.deliverConnectionError(quic::QuicError(
      quic::ApplicationErrorCode(WT_ERROR_1), "connection close"));
  eventBase_.loop();

  // Verify both tokens are cancelled
  EXPECT_TRUE(token1.isCancellationRequested());
  EXPECT_TRUE(token2.isCancellationRequested());

  // Verify the pending read received an error
  EXPECT_TRUE(readErrorReceived);
}

// Verifies that destroying the QuicWebTransport (with no explicit close)
// terminates every open stream. The streams covered are a peer uni stream,
// a peer bidi stream and a locally-created uni stream:
//  - all of their cancellation tokens are cancelled, and
//  - pending reads fail with a WebTransport::Exception whose error is 0.
// Each stream is surfaced with its own loopOnce() so its EXPECT_CALL is
// consumed before the next one is set.
TEST_F(QuicWebTransportTest, DestructorTerminatesOpenStreams) {
  WebTransport::StreamReadHandle* uniReadHandle = nullptr;
  WebTransport::StreamReadHandle* bidiReadHandle = nullptr;
  WebTransport::StreamWriteHandle* bidiWriteHandle = nullptr;
  WebTransport::StreamWriteHandle* uniWriteHandle = nullptr;

  // Create an incoming uni stream
  EXPECT_CALL(*handler_, onNewUniStream(_))
      .WillOnce([&](WebTransport::StreamReadHandle* handle) {
        uniReadHandle = handle;
      });
  socketDriver_.addReadEvent(2, nullptr, false);
  eventBase_.loopOnce();
  ASSERT_NE(uniReadHandle, nullptr);

  // Create an incoming bidi stream
  EXPECT_CALL(*handler_, onNewBidiStream(_))
      .WillOnce([&](WebTransport::BidiStreamHandle handle) {
        bidiReadHandle = handle.readHandle;
        bidiWriteHandle = handle.writeHandle;
      });
  socketDriver_.addReadEvent(0, nullptr, false);
  eventBase_.loopOnce();
  ASSERT_NE(bidiReadHandle, nullptr);
  ASSERT_NE(bidiWriteHandle, nullptr);

  // Create an outgoing uni stream
  auto uniStreamRes = webTransport()->createUniStream();
  ASSERT_TRUE(uniStreamRes.hasValue());
  uniWriteHandle = uniStreamRes.value();

  // Start pending read operations
  bool uniReadErrorReceived = false;
  bool bidiReadErrorReceived = false;

  // When transport is deleted with no code passed to close, it passes 0 to the
  // transport. Currently the transport echoes this error back to all open
  // streams, hence ex->error is 0.
  auto uniReadFuture = uniReadHandle->readStreamData();
  std::move(uniReadFuture)
      .via(&eventBase_)
      .thenTry([&](folly::Try<WebTransport::StreamData> result) {
        EXPECT_TRUE(result.hasException());
        auto* ex = result.tryGetExceptionObject<WebTransport::Exception>();
        ASSERT_NE(ex, nullptr);
        EXPECT_EQ(ex->error, 0);
        uniReadErrorReceived = true;
      });

  auto bidiReadFuture = bidiReadHandle->readStreamData();
  std::move(bidiReadFuture)
      .via(&eventBase_)
      .thenTry([&](folly::Try<WebTransport::StreamData> result) {
        EXPECT_TRUE(result.hasException());
        auto* ex = result.tryGetExceptionObject<WebTransport::Exception>();
        ASSERT_NE(ex, nullptr);
        EXPECT_EQ(ex->error, 0);
        bidiReadErrorReceived = true;
      });

  // Store cancellation tokens before destroying transport. The handle
  // pointers must not be used after transport_.reset() below; the tokens are
  // copies and stay valid.
  auto uniReadToken = uniReadHandle->getCancelToken();
  auto bidiReadToken = bidiReadHandle->getCancelToken();
  auto bidiWriteToken = bidiWriteHandle->getCancelToken();
  auto uniWriteToken = uniWriteHandle->getCancelToken();

  // Verify cancellation tokens are not yet cancelled
  EXPECT_FALSE(uniReadToken.isCancellationRequested());
  EXPECT_FALSE(bidiReadToken.isCancellationRequested());
  EXPECT_FALSE(bidiWriteToken.isCancellationRequested());
  EXPECT_FALSE(uniWriteToken.isCancellationRequested());

  // Destroy the transport - this triggers terminateSessionStreams in the
  // destructor which calls stopReadingWebTransportIngress with an error code.
  // The mock socket will transition stream states to ERROR, allowing clean
  // shutdown.
  transport_.reset();
  // Run the read continuations scheduled on eventBase_ above.
  eventBase_.loop();

  // Verify all cancellation tokens are now cancelled
  EXPECT_TRUE(uniReadToken.isCancellationRequested());
  EXPECT_TRUE(bidiReadToken.isCancellationRequested());
  EXPECT_TRUE(bidiWriteToken.isCancellationRequested());
  EXPECT_TRUE(uniWriteToken.isCancellationRequested());

  // Verify pending reads received errors
  EXPECT_TRUE(uniReadErrorReceived);
  EXPECT_TRUE(bidiReadErrorReceived);
}

// TODO:
//
// new streams with no handler
// receive connection end
// create bidi fails
// create uni fails
// writeChain fails
// write blocks
// reset fails
// pause/resume
// setReadCallback fails
// sendDatagram fails
// close with error
// await uni/bidi stream credit
